"""Greedy pursuit algorithm for weighted low-rank approximation (LRA).

[BRW2021] Bhaskara, Ruwanpathirana, Wijewardena, 2021.
https://proceedings.mlr.press/v139/bhaskara21a.html
"""

import numpy as np
from sklearn.utils.extmath import randomized_svd

from . import weighted_linreg


def weighted_lra(matrix, weight, rank, omp=False, eps=1e-8):
  """Weighted low rank approximation using a greedy pursuit.

  Builds a factorization ``left_factor @ right_factor`` of `matrix`. Each
  iteration appends a new column to the left factor: the top left singular
  vector of the weighted residual ``residual * weight``.

  If omp is set to True, we re-optimize for the right factor at each iteration
  (via ``weighted_linreg.weighted_linreg``). Otherwise, we follow the
  incremental update rule of [BRW2021]. Only the coefficients of the newly
  added column are fitted, by per-column weighted least squares against the
  current residual.

  Args:
    matrix: 2-D array-like to approximate (copied as float64 for the
      residual).
    weight: entrywise weights, multiplied (with broadcasting) against the
      residual.
    rank: number of greedy iterations, i.e. columns of the left factor.
      Values <= 0 yield empty factors.
    omp: whether to refit the whole right factor at every iteration.
    eps: denominators with absolute value below `eps` get `eps` added, to
      avoid dividing by zero in the incremental update.

  Returns:
    ``(left_factor, right_factor)``. For an (m, n) `matrix` and k = max(rank, 0):
    - `left_factor` has shape (m, k).
    - In the incremental mode, `right_factor` has shape (k, n).
    - In omp mode, `right_factor` is whatever ``weighted_linreg`` returns
      (presumably (k, n); TODO confirm). When k == 0 it is an empty
      (0, n) array.
  """
  error = np.array(matrix, dtype=np.float64)
  num_rows, num_cols = error.shape
  # Defaults for rank <= 0, where the loop below never runs; previously the
  # return statements hit unbound names in that case.
  left_factor = np.zeros((num_rows, 0))
  right_factor = np.zeros((0, num_cols))
  UT, VT = [], []
  for _ in range(rank):
    # The next greedy direction is the top left singular vector of the
    # weighted residual.
    u = randomized_svd(error * weight, n_components=1)[0][:, 0]
    UT.append(u)
    left_factor = np.transpose(np.array(UT))
    if omp:
      # Refit the entire right factor against all directions found so far.
      right_factor = weighted_linreg.weighted_linreg(matrix, weight, left_factor)
      error = matrix - left_factor @ right_factor
    else:
      # Per-column weighted least squares for the new coefficient row:
      #   coef_j = sum_i W_ij u_i E_ij / sum_i W_ij u_i^2
      numerator = np.dot(u, error * weight)
      denominator = np.dot(u, weight * u[..., np.newaxis])
      # Guard (near-)zero denominators, e.g. columns carrying no weight.
      denominator += eps * (np.abs(denominator) < eps)
      coef = np.divide(numerator, denominator)
      assert np.isfinite(coef).all()
      error -= np.outer(u, coef)
      VT.append(coef)
  if not omp and VT:
    # Stack the coefficient rows once, after all iterations.
    right_factor = np.array(VT)
  return left_factor, right_factor
